

# To reproduce the experiment, create five directories named "landbird", "waterbird", "snake", "landturtle" and "waterturtle" under your data root.
# The data root is written as the placeholder "....." / "...." in the paths below; replace every occurrence with the actual path.
# From ImageNet, copy the images of the following classes into those directories:
#   ruffed grouse, indigo bunting                ===> "landbird"
#   albatross, water ouzel                       ===> "waterbird"
#   thunder snake, ringneck snake, garter snake  ===> "snake"
#   mud turtle, box turtle                       ===> "landturtle"
#   loggerhead, leatherback turtle               ===> "waterturtle"
# After that, the experiment can be reproduced by running "python main.py".


import argparse
import numpy as np
import torch
import pandas as pd

import os, csv
import argparse
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np


from torchvision import datasets
from torch import nn, optim, autograd
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from mpl_toolkits.mplot3d import Axes3D
from torch.utils.data.dataset import Subset
from sklearn.model_selection import KFold
from additional_functions import mean_nll,mean_nll2,mean_accuracy,prob_sum,softmax,Condi_MI,mean_nll2_forIRM


# Command-line configuration. The arguments are parsed at import time, so the
# script is meant to be run directly rather than imported.
# NOTE(review): the description string looks like a leftover from another
# experiment; this script trains on ImageNet animal subsets - confirm.
parser = argparse.ArgumentParser(description='Colored MNIST')
# NOTE(review): --hidden_dim is not read anywhere in the visible part of this file.
parser.add_argument('--hidden_dim', type=int, default=440)
# Coefficient of the squared-L2 norm of all model parameters added to the loss.
parser.add_argument('--l2_regularizer_weight', type=float,default=0.001)
# Learning rate handed to the Adam optimizers.
parser.add_argument('--lr', type=float, default=0.0004)
# Number of passes over the training loaders (the `step` loops below).
parser.add_argument('--steps', type=int, default=35)
flags = parser.parse_args()


# Images are resized to target_resolution * scale and then centre-cropped to
# target_resolution by `transform` below.
scale = 256.0/224.
target_resolution = (224, 224)
# NOTE(review): `minibatch` is not used in the visible part of this file; the
# DataLoaders below hard-code their batch sizes.
minibatch = 128

# Rows of the final result table; re-initialised just before the
# hyper-parameter loop further down.
saving_data = []

def condi_prob(x):
    """Renormalise columns 1 and 2 of *x* so that each row sums to one.

    Column 0 is ignored. The result has one row per row of *x* and two
    columns: x[:, 1] / (x[:, 1] + x[:, 2]) and x[:, 2] / (x[:, 1] + x[:, 2]).
    """
    first = x[:, 1]
    second = x[:, 2]
    total = first + second
    pair = torch.stack([first / total, second / total])
    return pair.T


# Preprocessing applied to every image: resize to the up-scaled resolution,
# centre-crop back to `target_resolution`, then convert to a tensor.
_resize_hw = (int(target_resolution[0]*scale), int(target_resolution[1]*scale))
transform = transforms.Compose([
    transforms.Resize(_resize_hw),
    transforms.CenterCrop(target_resolution),
    transforms.ToTensor(),
])


def _load_image_folder(folder):
    """Load every file in *folder* as an RGB image and stack them into one tensor.

    Each file is opened with PIL, converted to RGB and passed through the
    module-level `transform`. Every entry returned by `os.listdir` is treated
    as an image, so the folder must contain image files only.
    """
    tensors = []
    for image_name in os.listdir(folder):
        file_name = os.path.join(folder, image_name)
        # The context manager releases the file handle as soon as the pixel
        # data has been converted.
        with Image.open(file_name) as raw:
            tensors.append(transform(raw.convert('RGB')))
    return torch.stack(tensors, dim=0)


def _build_env(class_groups):
    """Concatenate per-class image tensors; group k receives integer label k."""
    images = torch.cat(class_groups)
    labels = torch.cat(
        [torch.zeros(group.shape[0]) + k for k, group in enumerate(class_groups)]
    ).long()
    return {'images': images, 'labels': labels}


def _shuffle_env_in_place(env):
    """Apply the same random permutation to env['images'] and env['labels'].

    The NumPy global RNG state is saved, the images are shuffled in place
    through the tensor's NumPy view, the state is restored and the labels are
    shuffled, so both receive the identical sequence of swaps. The argument
    the original code passed to `get_state` is its `legacy` flag, not a seed;
    it is dropped here and the default (legacy=True) returns the same state.
    """
    rng_state = np.random.get_state()
    np.random.shuffle(env['images'].numpy())
    np.random.set_state(rng_state)
    np.random.shuffle(env['labels'].numpy())


def _split_env(env):
    """Return (train, test) dicts: the first fifth of the samples is the test part."""
    n_test = env['images'].shape[0]//5
    train = {'images': env['images'][n_test:], 'labels': env['labels'][n_test:]}
    test = {'images': env['images'][:n_test], 'labels': env['labels'][:n_test]}
    return train, test


# The path literals are placeholders (see the header comment) and must be
# edited to point at the real directories before running.
randbird_images = _load_image_folder('...../landbird')
waterbird_images = _load_image_folder('..../waterbird')
snake_images = _load_image_folder('..../snake')
randturtle_images = _load_image_folder('..../landturtle')
waterturtle_images = _load_image_folder('..../waterturtle')


# Labels: 0 = bird, 1 = snake, 2 = turtle. The snake images are split between
# the environments (even indices -> env2, odd indices -> env1).
# env2 = land birds + water turtles, env1 = water birds + land turtles.
env2 = _build_env([randbird_images, snake_images[::2,:,:,:], waterturtle_images])
env1 = _build_env([waterbird_images, snake_images[1::2,:,:,:], randturtle_images])

torch.set_printoptions(edgeitems=100000)

# No seed is set, so the shuffle (and therefore the split) differs per run.
_shuffle_env_in_place(env1)
_shuffle_env_in_place(env2)

env1_train, env1_test = _split_env(env1)
env2_train, env2_test = _split_env(env2)

print('train:',env1_train['images'].shape)
envs_train = [env1_train,env2_train]

print('test:',env1_test['labels'].shape)
envs_test = [env1_test,env2_test]

# Fraction of env2 training samples whose label is not 0; used as the weight
# of the correction term in the cross-validation score.
ratio = int((env2_train['labels'] != 0).sum()) / env2_train['labels'].shape[0]


    
class EnvDataset(torch.utils.data.Dataset):
    """Map-style dataset over one environment split.

    Args:
        env: dict with an 'images' tensor (indexed along dim 0) and a 'labels'
            tensor holding one label per image.

    Each item is the pair ``(image, label)`` where ``label`` is the matching
    row of the labels reshaped to ``(-1, 1)``.
    """

    def __init__(self, env):
        self.data = env['images']
        # Column layout so that a collated batch of labels has shape (batch, 1).
        self.label = env['labels'].view(-1, 1)
        self.datanum = env['images'].shape[0]

    def __len__(self):
        return self.datanum

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]


class Mydatasets_env1(EnvDataset):
    """Training split of environment 1 (reads the module-level `env1_train`)."""

    def __init__(self):
        super().__init__(env1_train)


class Mydatasets_env2(EnvDataset):
    """Training split of environment 2 (reads the module-level `env2_train`)."""

    def __init__(self):
        super().__init__(env2_train)


class Mydatasets_env1_test(EnvDataset):
    """Test split of environment 1 (reads the module-level `env1_test`)."""

    def __init__(self):
        super().__init__(env1_test)


class Mydatasets_env2_test(EnvDataset):
    """Test split of environment 2 (reads the module-level `env2_test`)."""

    def __init__(self):
        super().__init__(env2_test)
    
    
    
# Instantiate one dataset per environment and split (train / held-out test)
# from the module-level env*_train / env*_test dictionaries.
trainset1= Mydatasets_env1()
trainset1_test= Mydatasets_env1_test()
trainset2= Mydatasets_env2()
trainset2_test= Mydatasets_env2_test()

import torch
import torch.nn as nn

class block(nn.Module):
    """Residual unit: 1x1 conv -> 3x3 conv -> 1x1 conv, plus a skip connection.

    Args:
        first_conv_in_channels: channels of the incoming feature map.
        first_conv_out_channels: channels produced by the first two
            convolutions; the unit outputs four times as many.
        identity_conv: optional module applied to the skip path so that its
            shape matches the main path; ``None`` adds the input unchanged.
        stride: stride of the 3x3 convolution.
    """

    def __init__(self, first_conv_in_channels, first_conv_out_channels, identity_conv=None, stride=1):
        super().__init__()

        width = first_conv_out_channels
        expanded = width * 4

        self.conv1 = nn.Conv2d(first_conv_in_channels, width, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(width)

        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1)
        self.bn2 = nn.BatchNorm2d(width)

        self.conv3 = nn.Conv2d(width, expanded, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU()

        self.identity_conv = identity_conv

    def forward(self, x):
        # Keep a copy of the input for the skip path.
        shortcut = x.clone()

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.identity_conv is not None:
            shortcut = self.identity_conv(shortcut)
        out += shortcut

        return self.relu(out)
    
    
class ResNet(nn.Module):
    """Residual network (stages of 3/4/6/3 units) with two linear heads.

    The convolutional trunk followed by `fc` produces a 256-dimensional
    feature vector (`pre_forward`). On top of it:

    * `fc1` / `_main1` - 3-way head; `forward` returns `softmax` of its output.
    * `fc2` / `_main2` - single-output head; `high_forward` returns its raw output.

    `_main1` and `_main2` are `nn.Sequential` wrappers around `fc1` and `fc2`
    (the same modules, not copies), so callers can reach the heads through
    either name.

    Args:
        block: callable ``block(in_channels, out_channels, identity_conv, stride)``
            returning a module whose output has ``4 * out_channels`` channels.
    """

    def __init__(self, block):
        super(ResNet, self).__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2_x = self._make_layer(block, 3, res_block_in_channels=64, first_conv_out_channels=64, stride=1)

        self.conv3_x = self._make_layer(block, 4, res_block_in_channels=256,  first_conv_out_channels=128, stride=2)
        self.conv4_x = self._make_layer(block, 6, res_block_in_channels=512,  first_conv_out_channels=256, stride=2)
        self.conv5_x = self._make_layer(block, 3, res_block_in_channels=1024, first_conv_out_channels=512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512*4, 256)
        self.fc1 = nn.Linear(256, 3)
        self.fc2 = nn.Linear(256, 1)

        self._main1 = nn.Sequential(self.fc1)
        self._main2 = nn.Sequential(self.fc2)

    def pre_forward(self, x):
        """Return the 256-d features shared by both heads.

        The shape comments assume a 3x224x224 input.
        """
        x = self.conv1(x)    # (3, 224, 224)   -> (64, 112, 112)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)  # (64, 112, 112)  -> (64, 56, 56)

        x = self.conv2_x(x)  # (64, 56, 56)    -> (256, 56, 56)
        x = self.conv3_x(x)  # (256, 56, 56)   -> (512, 28, 28)
        x = self.conv4_x(x)  # (512, 28, 28)   -> (1024, 14, 14)
        x = self.conv5_x(x)  # (1024, 14, 14)  -> (2048, 7, 7)
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        return self.fc(x)

    def forward(self, x):
        """Return `softmax` applied to the 3-way head output."""
        return softmax(self._main1(self.pre_forward(x)))

    def high_forward(self, x):
        """Return the raw output of the single-output head (`fc2`)."""
        return self._main2(self.pre_forward(x))

    def condi_forward(self, x):
        """Return `condi_prob` of the 3-way probabilities (classes 1 and 2 renormalised)."""
        return condi_prob(self.forward(x))

    def sum_forward(self, x):
        """Return `prob_sum` applied to the 3-way probabilities."""
        return prob_sum(self.forward(x))

    def _make_layer(self, block, num_res_blocks, res_block_in_channels, first_conv_out_channels, stride):
        """Build one stage: a projecting first unit followed by identity units."""
        # Only the first unit changes channel count / resolution, so only it
        # needs a 1x1 convolution on the skip path.
        identity_conv = nn.Conv2d(res_block_in_channels, first_conv_out_channels*4, kernel_size=1,stride=stride)
        layers = [block(res_block_in_channels, first_conv_out_channels, identity_conv, stride)]

        in_channels = first_conv_out_channels*4
        for _ in range(num_res_blocks - 1):
            layers.append(block(in_channels, first_conv_out_channels, identity_conv=None, stride=1))

        return nn.Sequential(*layers)

def pretty_print(*values):
    """Print *values* as one row of left-aligned, 13-character-wide columns.

    Strings are printed as given; anything else is rendered with
    `np.array2string` using five fixed decimals. Columns are separated by
    three spaces.
    """
    col_width = 13
    cells = []
    for value in values:
        if isinstance(value, str):
            text = value
        else:
            text = np.array2string(value, precision=5, floatmode='fixed')
        cells.append(text.ljust(col_width))
    print("   ".join(cells))

def condi_prob(x):
    """Renormalise columns 1 and 2 of *x* so that each row sums to one.

    This rebinds the name already defined near the top of the file with the
    same formula: the result has one row per row of *x* and the two columns
    x[:, 1] / (x[:, 1] + x[:, 2]) and x[:, 2] / (x[:, 1] + x[:, 2]).
    """
    first = x[:, 1]
    second = x[:, 2]
    total = first + second
    pair = torch.stack([first / total, second / total])
    return pair.T


# Hyper-parameter grid. `iters` is the step index from which `penalty_weight`
# is applied to the gradient penalty (0.1 is used before that step).
iters_array = np.array([10,20,30])
penalty_weight_array =  np.array([0,1,  10, 100])
saving_data = []


def _make_penalty(model):
    """Return the gradient penalty function bound to *model*'s `_main2` head.

    The penalty is the squared norm of the gradient of `mean_nll2_forIRM`
    with respect to the first parameter of `_main2[0]`; `create_graph=True`
    keeps it differentiable so that it can be part of the training loss.
    """
    def penalty(logits, y):
        loss = mean_nll2_forIRM(logits , y)
        grad = autograd.grad(loss, model._main2[0].parameters(), create_graph=True)[0]
        return torch.sum(grad**2)
    return penalty


def _train_step(model, optimizer, penalty, batch1, batch2, penalty_weights):
    """Run one optimisation step on a pair of batches (env1, env2).

    The classification loss and accuracy use env1 only; the gradient penalty
    is averaged over both environments. Returns detached CPU copies of
    (train_nll, train_acc, train_penalty). Keeping the step in its own
    function releases the autograd graph and the GPU batches on return.
    """
    images1, labels1 = (t.cuda() for t in batch1)
    images2, labels2 = (t.cuda() for t in batch2)

    pre_logits1 = model.pre_forward(images1)
    pre_logits2 = model.pre_forward(images2)
    train_nll = mean_nll(torch.log(softmax(model._main1(pre_logits1))) ,labels1)
    train_acc = mean_accuracy(softmax(model._main1(pre_logits1)), labels1)

    # The penalty head is trained on the binary task "label > 0".
    penalty1 = penalty( model._main2(pre_logits1), (labels1 >0 ).float().view(-1,1) )
    penalty2 = penalty( model._main2(pre_logits2), (labels2 >0 ).float().view(-1,1) )

    print('train_nll:',train_nll)

    train_penalty = torch.stack([penalty1, penalty2]).mean()

    loss = train_nll.clone()

    weight_norm = torch.tensor(0.).cuda()
    for w in model.parameters():
        weight_norm += w.norm().pow(2)
    loss += flags.l2_regularizer_weight * weight_norm
    loss += penalty_weights * train_penalty
    if penalty_weights > 1.0:
        # Rescale the entire loss to keep gradients in a reasonable range
        loss /= penalty_weights

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return train_nll.detach().cpu(), train_acc.detach().cpu(), train_penalty.detach().cpu()


def _train_one_epoch(model, optimizer, penalty, loader1, loader2, penalty_weights):
    """Iterate the two loaders in lockstep; return per-batch (nll, acc, penalty) lists."""
    train_nll_store = []
    train_acc_store = []
    penalty_store = []
    for batch1, batch2 in zip(loader1, loader2):
        train_nll, train_acc, train_penalty = _train_step(
            model, optimizer, penalty, batch1, batch2, penalty_weights)
        train_nll_store.append(train_nll)
        train_acc_store.append(train_acc)
        penalty_store.append(train_penalty)
    return train_nll_store, train_acc_store, penalty_store


def _validate_fold(model, loader1, loader2):
    """Compute the mean validation losses of one cross-validation fold.

    Returns (out_nll1, out_nll2, out_bias):
      * out_nll1 - 3-class NLL on env1,
      * out_bias - NLL of `condi_forward` on the env1 samples with label != 0
        (labels shifted down by one),
      * out_nll2 - NLL of `sum_forward` on env2 for the binary task "label > 0".
    No gradients are needed here, so the loop runs under `torch.no_grad()`.
    """
    out_nll1_store = []
    out_nll2_store = []
    out_bias_store = []
    with torch.no_grad():
        for batch1, batch2 in zip(loader1, loader2):
            images1, labels1 = (t.cuda() for t in batch1)

            logits1 = model.forward(images1)
            out_nll1 = mean_nll(torch.log(logits1) ,labels1)
            out_nll1_store.append(out_nll1.detach().cpu())
            print('out_nll1 :', out_nll1)

            nonzero = labels1.view(-1)!=0
            logits1_bias = model.condi_forward(images1[nonzero,:])
            out_bias = mean_nll(torch.log(logits1_bias) ,labels1[nonzero,:] -1. )
            print('out_bias :', out_bias)
            out_bias_store.append(out_bias.detach().cpu())

            images2, labels2 = (t.cuda() for t in batch2)
            logits2 = model.sum_forward(images2)
            out_nll2 = mean_nll( torch.log(logits2), (labels2 >0 ).float().view(-1,1)  )
            out_nll2_store.append(out_nll2.detach().cpu())
            print('out_nll2 :', out_nll2)

    return (torch.stack(out_nll1_store, dim=0).mean(),
            torch.stack(out_nll2_store, dim=0).mean(),
            torch.stack(out_bias_store, dim=0).mean())


def _evaluate_test_sets(model, loader1, loader2):
    """Evaluate *model* on the held-out splits of both environments.

    Returns (acc_env1, acc_env2, per_label_env1, per_label_env2) where the
    per-label entries are lists with one mean accuracy for each of the labels
    0, 1 and 2. All values are means over the per-batch accuracies.
    """
    overall_store = ([], [])
    per_label_store = ([[], [], []], [[], [], []])
    with torch.no_grad():
        for batch1, batch2 in zip(loader1, loader2):
            images1, labels1 = (t.cuda() for t in batch1)
            images2, labels2 = (t.cuda() for t in batch2)

            logits1 = model.forward(images1)
            logits2 = model.forward(images2)
            test1_acc = mean_accuracy(logits1, labels1)
            test2_acc = mean_accuracy(logits2, labels2)

            for env_idx, (logits, labels) in enumerate(((logits1, labels1), (logits2, labels2))):
                flat_labels = labels.view(-1)
                for k in range(3):
                    # NOTE(review): a batch without any sample of label k
                    # yields an empty selection; what mean_accuracy returns
                    # for it depends on additional_functions - confirm.
                    selected = flat_labels==k
                    acc_k = mean_accuracy(logits[selected,:], labels[selected,:])
                    per_label_store[env_idx][k].append(acc_k.detach().cpu())

            print('test1_acc:',test1_acc)
            print('test2_acc:',test2_acc)

            overall_store[0].append(test1_acc.detach().cpu())
            overall_store[1].append(test2_acc.detach().cpu())

    def mean_of(store):
        return torch.stack(store, dim=0).mean()

    return (mean_of(overall_store[0]),
            mean_of(overall_store[1]),
            [mean_of(store) for store in per_label_store[0]],
            [mean_of(store) for store in per_label_store[1]])


def _cross_validate(iters, penalty_weight):
    """Run 5-fold cross-validation for one hyper-parameter setting.

    Returns (sbs_CV_store, max_CV_store): per fold, the maximum of the env1
    loss and the env2 loss, with and without the `out_bias * ratio`
    correction added to the env2 loss.
    """
    trainset1 = Mydatasets_env1()
    trainset2 = Mydatasets_env2()

    final_train_accs = []
    sbs_CV_store = []
    max_CV_store = []
    print('CV training')

    # The same fold indices are applied to both environments, so they must be
    # valid for the smaller one. When env1 is not larger than env2 this is
    # exactly the split the original `kf.split(trainset1)` produced.
    n_shared = min(len(trainset1), len(trainset2))
    kf = KFold(n_splits=5)
    for _fold, (train_index, valid_index) in enumerate(kf.split(np.arange(n_shared))):
        print('CV_step=',_fold)
        model_CV = ResNet(block).cuda()

        trainloader1 = torch.utils.data.DataLoader(Subset(trainset1, train_index), batch_size=56, shuffle=True, num_workers=6)
        trainloader2 = torch.utils.data.DataLoader(Subset(trainset2, train_index), batch_size=56, shuffle=True, num_workers=6)

        penalty = _make_penalty(model_CV)
        optimizer = optim.Adam(model_CV.parameters(), lr=flags.lr)

        pretty_print('step', 'train nll', 'train acc', 'train penalty','w*train penalty')

        for step in range(flags.steps):
            # Warm-up: a fixed small weight until `iters` steps have passed.
            penalty_weights = penalty_weight if step >= iters else 0.1

            train_nll_store, train_acc_store, penalty_store = _train_one_epoch(
                model_CV, optimizer, penalty, trainloader1, trainloader2, penalty_weights)

            train_acc = torch.stack(train_acc_store, dim=0).mean()
            train_nll = torch.stack(train_nll_store, dim=0).mean()
            train_penalty = torch.stack(penalty_store, dim=0).mean()
            pretty_print(
                np.int32(step),
                train_nll.detach().cpu().numpy(),
                train_acc.detach().cpu().numpy(),
                train_penalty.detach().cpu().numpy(),
                (train_penalty*penalty_weights).detach().cpu().numpy())

        final_train_accs.append(train_acc.detach().cpu().numpy())
        print('Final train acc (mean/std across restarts so far):')
        print(np.mean(final_train_accs), np.std(final_train_accs))

        validloader1 = torch.utils.data.DataLoader(Subset(trainset1, valid_index), batch_size=14, shuffle=True, num_workers=6)
        validloader2 = torch.utils.data.DataLoader(Subset(trainset2, valid_index), batch_size=14, shuffle=True, num_workers=6)
        out_nll1, out_nll2, out_bias = _validate_fold(model_CV, validloader1, validloader2)

        corrected = torch.max(torch.cat([out_nll1.view(-1), (out_nll2 + out_bias*ratio).view(-1)]))
        uncorrected = torch.cat([out_nll1.view(-1), out_nll2.view(-1)])
        sbs_CV_store.append(corrected)
        max_CV_store.append(torch.max(uncorrected))
        print('each_training_loss_with_correction:', corrected)
        print('each_training_loss:', uncorrected)

        # Release this fold's model and optimizer state before the next fold
        # allocates a new network.
        del model_CV, optimizer, penalty

    return sbs_CV_store, max_CV_store


def _train_final_model(iters, penalty_weight):
    """Train on the full training splits and evaluate on the test splits after every step.

    Returns (train_acc, test_acc_env1, test_acc_env2) of the last step.
    """
    print('start training')

    trainset1 = Mydatasets_env1()
    trainset1_test = Mydatasets_env1_test()
    trainset2 = Mydatasets_env2()
    trainset2_test = Mydatasets_env2_test()

    model = ResNet(block).cuda()

    trainloader1 = torch.utils.data.DataLoader(trainset1, batch_size=56, shuffle=True, num_workers=8)
    trainloader1_test = torch.utils.data.DataLoader(trainset1_test, batch_size=50, shuffle=True, num_workers=4)
    trainloader2 = torch.utils.data.DataLoader(trainset2, batch_size=56, shuffle=True, num_workers=4)
    trainloader2_test = torch.utils.data.DataLoader(trainset2_test, batch_size=50, shuffle=True, num_workers=8)

    penalty = _make_penalty(model)
    optimizer = optim.Adam(model.parameters(), lr=flags.lr)

    for step in range(flags.steps):
        penalty_weights = penalty_weight if step >= iters else 0.1

        train_nll_store, train_acc_store, _ = _train_one_epoch(
            model, optimizer, penalty, trainloader1, trainloader2, penalty_weights)

        test_acc_env1, test_acc_env2, env1_by_label, env2_by_label = _evaluate_test_sets(
            model, trainloader1_test, trainloader2_test)

        train_acc = torch.stack(train_acc_store, dim=0).mean()
        train_nll = torch.stack(train_nll_store, dim=0).mean()

        # One value per header column: the per-label accuracies of env1 are
        # printed under 'test1_label*' and those of env2 under 'test2_label*'.
        pretty_print('step', 'train nll', 'train acc', 'test1 acc', 'test2_acc',  'test1_label0', 'test1_label1',  'test1_label2', 'test2_label0', 'test2_label1',  'test2_label2')
        pretty_print(
            np.int32(step),
            train_nll.detach().cpu().numpy(),
            train_acc.detach().cpu().numpy(),
            test_acc_env1.detach().cpu().numpy(),
            test_acc_env2.detach().cpu().numpy(),
            *[acc.detach().cpu().numpy() for acc in env1_by_label],
            *[acc.detach().cpu().numpy() for acc in env2_by_label])

    return train_acc, test_acc_env1, test_acc_env2


for iters in iters_array:
    for penalty_weight in penalty_weight_array:

        print('iters:',iters)
        print('penalty_weight:',penalty_weight)

        sbs_CV_store, max_CV_store = _cross_validate(iters, penalty_weight)
        print('final_sbs:', np.mean(sbs_CV_store))
        print('final_max:', np.mean(max_CV_store))

        train_acc, test_acc_env1, test_acc_env2 = _train_final_model(iters, penalty_weight)

        # A single final run is made per setting, so each list holds one value
        # and the reported standard deviations are 0.
        final_train_accs = [train_acc.detach().cpu().numpy()]
        final_test_accs_env1 = [test_acc_env1.detach().cpu().numpy()]
        final_test_accs_env2 = [test_acc_env2.detach().cpu().numpy()]
        print('Final train acc (mean/std across restarts so far):')
        print(np.mean(final_train_accs), np.std(final_train_accs))
        print('Final test acc env1(mean/std across restarts so far):')
        print(np.mean(final_test_accs_env1), np.std(final_test_accs_env1))
        print('Final test acc env2(mean/std across restarts so far):')
        print(np.mean(final_test_accs_env2), np.std(final_test_accs_env2))

        saving_data.append([iters, penalty_weight, np.mean(sbs_CV_store), np.mean(max_CV_store), np.mean(final_test_accs_env1), np.std(final_test_accs_env1), np.mean(final_test_accs_env2),  np.std(final_test_accs_env2)])

# One row per (iters, penalty_weight) setting: the two cross-validation scores
# ('CVII' = with correction, 'CVI' = without) and the final test accuracies.
sample = pd.DataFrame(saving_data, columns=['iters','penalty_weight1','CVII','CVI', 'test1_acc','test1_std', 'test2_acc','test2_std'])
print(sample)

sample.to_csv('CVresult.csv')



